$\newcommand{\vect}[1]{{\mathbf{\boldsymbol{#1}} }}$ $\newcommand{\amax}{{\text{argmax}}}$ $\newcommand{\P}{{\mathbb{P}}}$ $\newcommand{\E}{{\mathbb{E}}}$ $\newcommand{\R}{{\mathbb{R}}}$ $\newcommand{\Z}{{\mathbb{Z}}}$ $\newcommand{\N}{{\mathbb{N}}}$ $\newcommand{\C}{{\mathbb{C}}}$ $\newcommand{\abs}[1]{{ \left| #1 \right| }}$ $\newcommand{\simpl}[1]{{\Delta^{#1} }}$

Anomaly Detection via Reconstruction Error

import ipywidgets as widgets
import itertools as it
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px


from ipywidgets import interact
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.model_selection import RandomizedSearchCV
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import LabelBinarizer
from sklearn.ensemble import IsolationForest
from sklearn import metrics
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
from sklearn.neighbors import KernelDensity
from tensorflow import keras
from tqdm import tqdm

from tfl_training_anomaly_detection.exercise_tools import (
    evaluate, 
    get_kdd_data, 
    get_house_prices_data, 
    create_distributions, 
    contamination, 
    perform_rkde_experiment, 
    get_mnist_data
)
from tfl_training_anomaly_detection.vae import VAE, build_decoder_mnist, build_encoder_minst, build_contaminated_minst

%matplotlib inline
matplotlib.rcParams['figure.figsize'] = (5, 5)

Exercise: Anomaly Detection on the MNIST Data Set

MNIST is one of the most iconic data sets in the history of machine learning. It contains 70,000 grayscale images of handwritten digits, each $28\times 28$ pixels. Because of its moderate complexity and ease of visualization, it is well suited to studying the behavior of machine learning algorithms in higher-dimensional spaces.

While originally created for classification (optical character recognition), we can build an anomaly detection data set by corrupting some of the images.

Pre-processing

We first need to obtain the MNIST data set and prepare an anomaly detection set from it. Note that the data set is in row-vector format, so we work with $28\times 28 = 784$-dimensional data points.

# load MNIST Data Set
mnist = get_mnist_data()

data = mnist['data']
print('data.shape: {}'.format(data.shape))
target = mnist['target'].astype(int)
data.shape: (70000, 784)

Build a Contaminated Data Set

We have prepared a function that does this job for us. It corrupts a prescribed portion of the data by applying a rotation, additive noise, or a blackout of part of the image.
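The actual transformations live in `build_contaminated_minst`; a minimal sketch of the three corruption types, assuming `scipy` is available (angle ranges, noise scale, and band size here are illustrative choices, not the helper's exact parameters):

```python
import numpy as np
from scipy.ndimage import rotate


def corrupt_image(img, mode, rng=None):
    """Apply one of three corruption types to a 28x28 image in [0, 1].

    Illustrative sketch only; build_contaminated_minst may differ in details.
    """
    rng = np.random.default_rng(0) if rng is None else rng
    img = img.copy()
    if mode == "rotation":
        # rotate by a random angle, keeping the 28x28 shape
        img = rotate(img, angle=rng.uniform(30, 90), reshape=False)
    elif mode == "noise":
        # additive Gaussian pixel noise
        img = img + rng.normal(0.0, 0.3, size=img.shape)
    elif mode == "blackout":
        # zero out a horizontal band of the image
        start = rng.integers(0, 20)
        img[start:start + 8, :] = 0.0
    else:
        raise ValueError(f"unknown mode: {mode}")
    return np.clip(img, 0.0, 1.0)
```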

First, we need to transform the data into image format.

X = data.reshape(-1, 28, 28, 1)/255

Train/Test-Split

We will only corrupt the test set, so we perform the train/test split beforehand. We separate a relatively small test set so that we can use as much of the data as possible to obtain high-quality representations.

test_size = .1
X_train, X_test, target_train, target_test = train_test_split(X, target, test_size=test_size)
X_test, y_test = build_contaminated_minst(X_test)

# Visualize contamination
anomalies = X_test[y_test != 0]
selection = np.random.choice(len(anomalies), 25)

fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(5, 5))
for img, ax in zip(anomalies[selection], axes.flatten()):
    ax.imshow(img, 'gray')
    ax.axis('off')
plt.show()

Autoencoder

Let us finally train an autoencoder model. We replicate the model given in the Keras documentation and apply it in a synthetic outlier detection scenario based on MNIST.

In the vae package we provide the implementation of the VAE. Please take a look at the source code to see how the minimization of the KL divergence is implemented.
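In that implementation the encoder outputs `z_mean` and `z_log_var` for a diagonal Gaussian, and the KL divergence to the standard normal prior has a closed form. A minimal NumPy sketch of that term (the package source uses the equivalent TensorFlow ops):

```python
import numpy as np


def kl_to_standard_normal(z_mean, z_log_var):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ), averaged over the batch.

    Closed form per sample: -0.5 * sum(1 + log sigma^2 - mu^2 - sigma^2).
    """
    kl_per_sample = -0.5 * np.sum(
        1.0 + z_log_var - np.square(z_mean) - np.exp(z_log_var), axis=-1
    )
    return kl_per_sample.mean()
```

For `z_mean = 0` and `z_log_var = 0` the posterior equals the prior and the KL term vanishes, which is why the training log above shows `kl_loss` collapsing toward zero.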

Create Model

latent_dim = 3
vae = VAE(decoder=build_decoder_mnist(latent_dim=latent_dim), encoder=build_encoder_minst(latent_dim=latent_dim))
## Inspect model architecture
vae.encoder.summary()
Model: "encoder"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
==================================================================================================
 input_2 (InputLayer)        [(None, 28, 28, 1)]          0         []                            
                                                                                                  
 conv2d (Conv2D)             (None, 14, 14, 32)           320       ['input_2[0][0]']             
                                                                                                  
 conv2d_1 (Conv2D)           (None, 7, 7, 64)             18496     ['conv2d[0][0]']              
                                                                                                  
 flatten (Flatten)           (None, 3136)                 0         ['conv2d_1[0][0]']            
                                                                                                  
 dense_1 (Dense)             (None, 16)                   50192     ['flatten[0][0]']             
                                                                                                  
 z_mean (Dense)              (None, 3)                    51        ['dense_1[0][0]']             
                                                                                                  
 z_log_var (Dense)           (None, 3)                    51        ['dense_1[0][0]']             
                                                                                                  
 sampling (Sampling)         (None, 3)                    0         ['z_mean[0][0]',              
                                                                     'z_log_var[0][0]']           
                                                                                                  
==================================================================================================
Total params: 69110 (269.96 KB)
Trainable params: 69110 (269.96 KB)
Non-trainable params: 0 (0.00 Byte)
__________________________________________________________________________________________________
## Inspect model architecture
vae.decoder.summary()
Model: "decoder"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 input_1 (InputLayer)        [(None, 3)]               0         
                                                                 
 dense (Dense)               (None, 3136)              12544     
                                                                 
 reshape (Reshape)           (None, 7, 7, 64)          0         
                                                                 
 conv2d_transpose (Conv2DTr  (None, 14, 14, 64)        36928     
 anspose)                                                        
                                                                 
 conv2d_transpose_1 (Conv2D  (None, 28, 28, 32)        18464     
 Transpose)                                                      
                                                                 
 conv2d_transpose_2 (Conv2D  (None, 28, 28, 1)         289       
 Transpose)                                                      
                                                                 
=================================================================
Total params: 68225 (266.50 KB)
Trainable params: 68225 (266.50 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
# train model
n_epochs = 30

vae.compile(optimizer=keras.optimizers.Adam(learning_rate=.001))
history = vae.fit(X_train, epochs=n_epochs, batch_size=128)
WARNING:absl:At this time, the v2.11+ optimizer `tf.keras.optimizers.Adam` runs slowly on M1/M2 Macs, please use the legacy Keras optimizer instead, located at `tf.keras.optimizers.legacy.Adam`.
Epoch 1/30
493/493 [==============================] - 13s 25ms/step - loss: 45.6633 - reconstruction_loss: 37.0972 - kl_loss: 0.0079
Epoch 2/30
493/493 [==============================] - 12s 25ms/step - loss: 33.7199 - reconstruction_loss: 33.6361 - kl_loss: 3.5348e-04
Epoch 3/30
493/493 [==============================] - 12s 25ms/step - loss: 33.5963 - reconstruction_loss: 33.6036 - kl_loss: 2.0530e-04
Epoch 4/30
493/493 [==============================] - 13s 25ms/step - loss: 33.6163 - reconstruction_loss: 33.5939 - kl_loss: 1.0495e-04
Epoch 5/30
493/493 [==============================] - 13s 27ms/step - loss: 33.6419 - reconstruction_loss: 33.5932 - kl_loss: 7.7540e-05
Epoch 6/30
493/493 [==============================] - 18s 36ms/step - loss: 33.6083 - reconstruction_loss: 33.5874 - kl_loss: 6.2611e-05
Epoch 7/30
493/493 [==============================] - 13s 27ms/step - loss: 33.6741 - reconstruction_loss: 33.5896 - kl_loss: 5.4438e-05
Epoch 8/30
493/493 [==============================] - 16s 33ms/step - loss: 33.6057 - reconstruction_loss: 33.5872 - kl_loss: 4.2172e-05
Epoch 9/30
493/493 [==============================] - 40s 82ms/step - loss: 33.6389 - reconstruction_loss: 33.5861 - kl_loss: 3.8024e-05
Epoch 10/30
493/493 [==============================] - 20s 40ms/step - loss: 33.5880 - reconstruction_loss: 33.5863 - kl_loss: 3.8647e-05
Epoch 11/30
493/493 [==============================] - 20s 41ms/step - loss: 33.6217 - reconstruction_loss: 33.5832 - kl_loss: 3.3977e-05
Epoch 12/30
493/493 [==============================] - 20s 41ms/step - loss: 33.6360 - reconstruction_loss: 33.5833 - kl_loss: 2.7641e-05
Epoch 13/30
493/493 [==============================] - 21s 42ms/step - loss: 33.6757 - reconstruction_loss: 33.5836 - kl_loss: 3.2303e-05
Epoch 14/30
493/493 [==============================] - 18s 36ms/step - loss: 33.6116 - reconstruction_loss: 33.5833 - kl_loss: 3.2345e-05
Epoch 15/30
493/493 [==============================] - 12s 25ms/step - loss: 33.6427 - reconstruction_loss: 33.5814 - kl_loss: 2.8013e-05
Epoch 16/30
493/493 [==============================] - 12s 25ms/step - loss: 33.6071 - reconstruction_loss: 33.5800 - kl_loss: 2.9247e-05
Epoch 17/30
493/493 [==============================] - 12s 25ms/step - loss: 33.6536 - reconstruction_loss: 33.5795 - kl_loss: 2.7215e-05
Epoch 18/30
493/493 [==============================] - 12s 25ms/step - loss: 33.6130 - reconstruction_loss: 33.5788 - kl_loss: 2.3973e-05
Epoch 19/30
493/493 [==============================] - 13s 27ms/step - loss: 33.6559 - reconstruction_loss: 33.5768 - kl_loss: 2.2127e-05
Epoch 20/30
493/493 [==============================] - 13s 27ms/step - loss: 33.5886 - reconstruction_loss: 33.5760 - kl_loss: 2.0363e-05
Epoch 21/30
493/493 [==============================] - 13s 26ms/step - loss: 33.6343 - reconstruction_loss: 33.5759 - kl_loss: 1.9129e-05
Epoch 22/30
493/493 [==============================] - 13s 26ms/step - loss: 33.6070 - reconstruction_loss: 33.5786 - kl_loss: 2.1504e-05
Epoch 23/30
493/493 [==============================] - 13s 26ms/step - loss: 33.6471 - reconstruction_loss: 33.5763 - kl_loss: 2.2993e-05
Epoch 24/30
493/493 [==============================] - 13s 26ms/step - loss: 33.6794 - reconstruction_loss: 33.5768 - kl_loss: 1.9568e-05
Epoch 25/30
493/493 [==============================] - 14s 29ms/step - loss: 33.6503 - reconstruction_loss: 33.5762 - kl_loss: 1.9934e-05
Epoch 26/30
493/493 [==============================] - 14s 28ms/step - loss: 33.6081 - reconstruction_loss: 33.5754 - kl_loss: 1.7368e-05
Epoch 27/30
493/493 [==============================] - 14s 28ms/step - loss: 33.6239 - reconstruction_loss: 33.5740 - kl_loss: 1.4350e-05
Epoch 28/30
493/493 [==============================] - 13s 26ms/step - loss: 33.6507 - reconstruction_loss: 33.5750 - kl_loss: 1.7213e-05
Epoch 29/30
493/493 [==============================] - 13s 25ms/step - loss: 33.6671 - reconstruction_loss: 33.5759 - kl_loss: 1.6418e-05
Epoch 30/30
493/493 [==============================] - 12s 25ms/step - loss: 33.6455 - reconstruction_loss: 33.5758 - kl_loss: 1.3022e-05

Inspect Result

def plot_latent_space(vae: VAE, n: int=10, figsize: float=10):
    """Plot sample images from 2D slices of latent space
    
    @param vae: vae model
    @param n: sample nXn images per slice
    @param figsize: figure size
    
    """
    for perm in [[0, 1, 2], [1, 2, 0], [2, 1, 0]]:
        # display an n x n 2D manifold of digits
        digit_size = 28
        scale = 1.0
        figure = np.zeros((digit_size * n, digit_size * n))
        # linearly spaced coordinates corresponding to the 2D plot
        # of digit classes in the latent space
        grid_x = np.linspace(-scale, scale, n)
        grid_y = np.linspace(-scale, scale, n)[::-1]

        for i, yi in enumerate(grid_y):
            for j, xi in enumerate(grid_x):
                z_sample = np.array([[xi, yi, 0]])
                z_sample[0] = z_sample[0][perm]
                x_decoded = vae.decoder.predict(z_sample)
                digit = x_decoded[0].reshape(digit_size, digit_size)
                figure[
                    i * digit_size : (i + 1) * digit_size,
                    j * digit_size : (j + 1) * digit_size,
                ] = digit

        plt.figure(figsize=(figsize, figsize))
        start_range = digit_size // 2
        end_range = n * digit_size + start_range
        pixel_range = np.arange(start_range, end_range, digit_size)
        sample_range_x = np.round(grid_x, 1)
        sample_range_y = np.round(grid_y, 1)
        plt.xticks(pixel_range, sample_range_x)
        plt.yticks(pixel_range, sample_range_y)
        plt.xlabel("z[{}]".format(perm[0]))
        plt.ylabel("z[{}]".format(perm[1]))
        plt.gca().set_title('z[{}] = 0'.format(perm[2]))
        plt.imshow(figure, cmap="Greys_r")
        plt.show()
plot_latent_space(vae)
# Principal components
pca = PCA()
latents = vae.encoder.predict(X_train)[2]
pca.fit(latents)

kwargs = {'x_{}'.format(i): (-1., 1.) for i in range(latent_dim)}


@widgets.interact(**kwargs)
def explore_latent_space(**kwargs):
    """Widget to explore latent space from given start position
    """
    center_img = pca.transform(np.zeros([1,latent_dim]))

    latent_rep_pca =  center_img + np.array([[kwargs[key] for key in kwargs]])
    latent_rep = pca.inverse_transform(latent_rep_pca)
    img = vae.decoder(latent_rep).numpy().reshape(28, 28)

    fig, ax = plt.subplots()
    ax.axis('off')

    ax.imshow(img,cmap='gray', vmin=0, vmax=1)
    plt.show()
1969/1969 [==============================] - 1s 709us/step
latents = vae.encoder.predict(X_train)[2]
scatter = px.scatter_3d(x=latents[:, 0], y=latents[:, 1], z=latents[:, 2], color=target_train)

scatter.show()
1969/1969 [==============================] - 1s 754us/step
latents = vae.encoder.predict(X_test)[2]
scatter = px.scatter_3d(x=latents[:, 0], y=latents[:, 1], z=latents[:, 2], color=y_test)

scatter.show()
219/219 [==============================] - 0s 803us/step
X_test, X_val, y_test, y_val = train_test_split(X_test, y_test)
n_samples = 10

idx = np.random.choice(range(len(X_val)), n_samples)
s = X_val[idx]

fig, axes = plt.subplots(nrows=2, ncols=n_samples, figsize=(10, 2))
for img, ax_row in zip(s, axes.T):
    x = vae.decoder.predict(vae.encoder.predict(img.reshape(1, 28, 28, 1))[2]).reshape(28, 28)
    diff = x - img.reshape(28, 28)
    error = (diff * diff).sum()
    ax_row[0].axis('off')
    ax_row[1].axis('off')
    ax_row[0].imshow(img,cmap='gray', vmin=0, vmax=1)
    ax_row[1].imshow(x, cmap='gray', vmin=0, vmax=1)
    ax_row[1].set_title('E={:.1f}'.format(error))

plt.tight_layout()
plt.show()
from sklearn import metrics
y_test_bin = y_test.copy()
y_test_bin[y_test != 0] = 1
y_val_bin = y_val.copy()
y_val_bin[y_val != 0] = 1
# Evaluate
reconstruction = vae.decoder.predict(vae.encoder(X_val)[2])
rerrors = (reconstruction - X_val).reshape(-1, 28*28)
rerrors = (rerrors * rerrors).sum(axis=1)

# Compute scores only if any anomaly is present
if np.any(y_val_bin == 1):
    eval_result = evaluate(y_val_bin.astype(int), rerrors.astype(float))
    pr, rec, thr = eval_result['PR']
    f1s = 2 * (pr * rec)[:-1] / (pr + rec)[:-1]
    threshold = thr[np.argmax(f1s)]
    print('Optimal threshold: {}'.format(threshold))

    reconstruction = vae.decoder.predict(vae.encoder(X_test)[2])
    reconstruction_error = (reconstruction - X_test).reshape(-1, 28*28)
    reconstruction_error = (reconstruction_error * reconstruction_error).sum(axis=1)


    classification = (reconstruction_error > threshold).astype(int)

    print('Precision: {}'.format(metrics.precision_score(y_test_bin, classification)))
    print('Recall: {}'.format(metrics.recall_score(y_test_bin, classification)))
    print('F1: {}'.format(metrics.f1_score(y_test_bin, classification)))

    metrics.confusion_matrix(y_test_bin, classification)
else:
    reconstruction_error = None
55/55 [==============================] - 0s 3ms/step
Optimal threshold: 103.85885094252916
165/165 [==============================] - 0s 2ms/step
Precision: 0.4074074074074074
Recall: 0.20754716981132076
F1: 0.275
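The threshold selection above can also be reproduced with `sklearn.metrics.precision_recall_curve` alone. A small self-contained sketch on toy scores (the labels and scores here are made up for illustration, not taken from the experiment):

```python
import numpy as np
from sklearn.metrics import precision_recall_curve


def best_f1_threshold(y_true, scores):
    """Pick the score threshold that maximizes F1 on a validation set."""
    precision, recall, thresholds = precision_recall_curve(y_true, scores)
    # precision/recall have one more entry than thresholds; drop the final
    # (precision=1, recall=0) point so the arrays align with thresholds
    f1 = 2 * precision[:-1] * recall[:-1] / np.clip(
        precision[:-1] + recall[:-1], 1e-12, None
    )
    return thresholds[np.argmax(f1)]


y_true = np.array([0, 0, 0, 0, 1, 1])
scores = np.array([0.1, 0.2, 0.3, 0.4, 0.8, 0.9])
thr = best_f1_threshold(y_true, scores)  # 0.8: perfect separation here
```

On a perfectly separable toy example any threshold between the highest inlier score and the lowest anomaly score works; `precision_recall_curve` evaluates only the observed score values, so the lowest anomaly score (0.8) is returned.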

Sort Data by Reconstruction Error

if reconstruction_error is not None:
    combined = list(zip(X_test, reconstruction_error))
    combined.sort(key = lambda x: x[1])

Show Top Autoencoder Outliers

if reconstruction_error is not None:
    n_rows = 10
    n_cols = 10
    n_samples = n_rows*n_cols

    samples = [c[0] for c in combined[-n_samples:]]

    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(2*n_cols, 2*n_rows))
    for img, ax in zip(samples, axes.reshape(-1)):
        ax.axis('off')
        ax.imshow(img.reshape((28,28)), cmap='gray', vmin=0, vmax=1)

    plt.show()